
 #提交小说文档路径
# novel_path = './novels/Murder_on_the_Orient_Express_Novel.docx'
# novel_path = './novels/The_Time_Machine_Novel.docx'
# novel_path = './novels/Live_by_Night_Novel.docx'
# novel_path = './novels/Fight_Club_Novel.docx'
novel_path = './novels/Amadeus_novel.docx'

is_load = True # 是否断点续传（记得改下面这三个参数）
# 使用默认设置（不指定工作目录）
load_output_save_dir = './shot_play_scripts' #大目录
load_basename,load_unique_id = "",None # 默认会自动获取提交文件名、当前时间
# 实际输出目录为：'./shot_play_scripts/{time}_{filename}/***'

ablation = None #None,Refine,PlotGraph,Reference

from generate_modules.utils import print_config

# [Load model 1] Global element extractor for the submitted novel.
from generate_modules.global_extractor import GlobalExtractor
from generate_modules.config import ConfigGlobalExtract, ConfigCallAPI

extractor1 = GlobalExtractor(
    ConfigGlobalExtract(
        model_config=ConfigCallAPI(),
        novel_path=novel_path,
        output_save_dir=load_output_save_dir,
        # "" -> generated automatically (defaults to the submitted file name)
        output_basename=load_basename,
        # None -> generated automatically (from the current time)
        unique_id=load_unique_id,
        # Whether to resume from saved files (defaults to False)
        load_data=is_load,
    )
)

# [Print model parameters]
print_config(extractor1.config)


# Reuse the basename / unique id resolved by model 1 so that model 2 is
# configured with the same output naming.
global_basename = extractor1.basename
global_unique_id = extractor1.unique_id

# 1. [Extract global elements]
print('\nextract_overall_elements:')
elements = extractor1.extract_overall_elements(
    chapter_nums=2,
    chapter_windows=4,
    # The "Refine" ablation disables refinement rounds entirely.
    refine_rounds=4 if ablation != "Refine" else 0,
    ablation=ablation,
    # temperature=0.8,
)
message = elements['message']
if message == 'Fail':
    raise Exception('Fail to extract_overall_elements.')
if message != 'Success':
    # Previously any other status fell through silently, leaving
    # `global_elements` unbound and crashing later with a NameError.
    raise Exception(
        f'Unexpected message from extract_overall_elements: {message!r}'
    )
global_elements = elements['global_elements']

# plot_num = len(global_elements['merged_plots'])
# print(f"一共要生成 {plot_num} 集")


# [Load model 2] Script rewriter; it writes under the same basename / unique id
# that model 1 resolved.
from generate_modules.rewriter import Rewriter
from generate_modules.config import ConfigScriptRewriter

extractor2 = Rewriter(
    ConfigScriptRewriter(
        model_config=ConfigCallAPI(),
        output_save_dir=load_output_save_dir,
        # "" -> generated automatically (defaults to the submitted file name)
        output_basename=global_basename,
        # None -> generated automatically (from the current time)
        unique_id=global_unique_id,
        # Whether to resume from saved files (defaults to False)
        load_data=is_load,
    )
)

# [Print model parameters] (disabled)
# print_config(extractor2.config)

# 2. [Rewrite scripts] Turn the extracted global elements into scripts.
print('\nrewrite_scripts:')
elements = extractor2.script_rewriter(
    global_elements=global_elements,
    diagram_mode='Topo',  # chapter, Topo, DFS
    # None, or: Three-Act Structure, Freytag’s Pyramid, Hero’s Journey, Four-Act Structure
    screenplay_structure=None,
    # The "Refine" ablation disables refinement rounds entirely.
    refine_rounds=4 if ablation != "Refine" else 0,
    ablation=ablation,
    # temperature=0.8,
)
message = elements['message']
if message == 'Fail':
    raise Exception('Fail to rewriter scripts.')
if message != 'Success':
    # Previously any other status fell through silently, leaving
    # `rewriter_elements` unbound for any code that uses it afterwards.
    raise Exception(f'Unexpected message from script_rewriter: {message!r}')
rewriter_elements = elements['rewriter_elements']